package net.j4love.mybatis.kit.sql;

import com.alibaba.druid.sql.SQLUtils;
import com.alibaba.druid.sql.ast.SQLStatement;
import com.alibaba.druid.sql.visitor.SchemaStatVisitor;
import com.alibaba.druid.stat.TableStat;
import org.apache.commons.lang3.StringUtils;
import org.apache.ibatis.mapping.SqlCommandType;

import java.util.ArrayList;
import java.util.HashMap;
import java.util.HashSet;
import java.util.LinkedHashMap;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;

/**
 * Static helpers that run the Druid SQL parser and schema visitor over a SQL string to find the
 * tables it references.
 */
public class SQLParserUtils {

    /**
     * Parses {@code sql} and classifies every referenced table as written or only read.
     *
     * <p>A table is reported with {@link SqlCommandType#UPDATE} when Druid recorded at least one
     * INSERT, UPDATE or DELETE against it in any of the parsed statements, and with
     * {@link SqlCommandType#SELECT} otherwise.
     *
     * @param sql    the SQL text, possibly containing several statements; {@code null} or blank
     *               input yields an empty list
     * @param dbType the Druid database type name handed to the parser and the schema visitor
     * @return a new mutable list with one {@link TableCommand} per distinct table name, in the
     *         order the visitor reports the tables; never {@code null}
     */
    public static List<TableCommand> getTables(String sql , String dbType) {
        Map<TableStat.Name, TableStat> tables = collectTableStats(sql, dbType);

        // Distinct Name keys may resolve to the same table name; the table counts as written if
        // any of them is. LinkedHashMap keeps the visitor's order so the output is deterministic.
        Map<String, Boolean> writtenByTable = new LinkedHashMap<>(tables.size());
        for (Map.Entry<TableStat.Name, TableStat> entry : tables.entrySet()) {
            String tableName = entry.getKey().getName();
            boolean written = isWrite(entry.getValue())
                    || Boolean.TRUE.equals(writtenByTable.get(tableName));
            writtenByTable.put(tableName, written);
        }

        List<TableCommand> tableCommands = new ArrayList<>(writtenByTable.size());
        for (Map.Entry<String, Boolean> entry : writtenByTable.entrySet()) {
            TableCommand.TableCommandBuilder builder = TableCommand.builder().tableName(entry.getKey());
            if (entry.getValue()) {
                builder.sqlCommandType(SqlCommandType.UPDATE);
            } else {
                builder.sqlCommandType(SqlCommandType.SELECT);
            }
            tableCommands.add(builder.build());
        }
        return tableCommands;
    }

    /**
     * Parses {@code sql} and returns the names of every table it references.
     *
     * @param sql    the SQL text, possibly containing several statements; {@code null} or blank
     *               input yields an empty set
     * @param dbType the Druid database type name handed to the parser and the schema visitor
     * @return a new mutable set of table names in the order the visitor reports them; never
     *         {@code null}
     */
    public static Set<String> getTableNames(String sql , String dbType) {
        Set<String> tableNames = new LinkedHashSet<>();
        for (TableStat.Name tableName : collectTableStats(sql, dbType).keySet()) {
            tableNames.add(tableName.getName());
        }
        return tableNames;
    }

    /**
     * Parses {@code sql} and feeds every statement through one shared schema visitor.
     *
     * @return the visitor's per-table statistics covering all statements, or an empty map when
     *         {@code sql} is {@code null} or blank
     */
    private static Map<TableStat.Name, TableStat> collectTableStats(String sql, String dbType) {
        if (StringUtils.isBlank(sql)) {
            // Blank input parses to no statements anyway; null would otherwise reach the parser.
            return new HashMap<>(0);
        }
        List<SQLStatement> sqlStatements = SQLUtils.parseStatements(sql, dbType);
        SchemaStatVisitor visitor = SQLUtils.createSchemaStatVisitor(dbType);
        for (SQLStatement sqlStatement : sqlStatements) {
            sqlStatement.accept(visitor);
        }
        // The visitor accumulates statistics across accept() calls, so a single read after the
        // loop covers every statement; re-reading inside the loop only re-processed old entries.
        return visitor.getTables();
    }

    /**
     * Returns whether Druid recorded an INSERT, UPDATE or DELETE against the table.
     *
     * <p>Reads the counters directly instead of substring-matching {@code stat.toString()},
     * whose format is a debugging aid and whose upper-casing depended on the default locale.
     * NOTE(review): MERGE ({@code getMergeCount()}) is deliberately not treated as a write, to
     * keep the classification unchanged — confirm whether MERGE targets should count.
     */
    private static boolean isWrite(TableStat stat) {
        return stat.getInsertCount() > 0
                || stat.getUpdateCount() > 0
                || stat.getDeleteCount() > 0;
    }

}
